import torch
import numpy as np
from matplotlib import pyplot as plt
import torch.nn as nn

import scipy.io.wavfile
import scipy.signal

import glob

# Audio / STFT configuration shared by the helpers in this module.
RATE = 22050            # sample rate (Hz) that every clip is required to have
NPERSEG = 256           # number of samples in each STFT segment
NOVERLAP = NPERSEG - 32  # overlap between segments, i.e. segments advance by 32 samples

def stft(x):
    """Return the one-sided complex STFT matrix of ``x``.

    Segment length and overlap come from the module-level ``NPERSEG`` and
    ``NOVERLAP``; every other ``scipy.signal.stft`` option keeps its default.
    The frequency and time axes returned by SciPy are discarded.
    """
    _freqs, _times, spectrum = scipy.signal.stft(
        x,
        nperseg=NPERSEG,
        noverlap=NOVERLAP,
        return_onesided=True,
    )
    return spectrum

def istft(S):
    """Return the time-domain signal reconstructed from the one-sided STFT ``S``.

    Uses the same ``NPERSEG`` / ``NOVERLAP`` settings as ``stft`` so the two
    functions act as a matched pair. The time axis returned by SciPy is
    discarded.
    """
    _times, signal = scipy.signal.istft(
        S,
        nperseg=NPERSEG,
        noverlap=NOVERLAP,
        input_onesided=True,
    )
    return signal

def invert_spectrogram(S, n_iter=100):
    """Estimate a time-domain signal for the spectrogram ``S``.

    The first estimate is the plain inverse STFT of ``S``. Each refinement
    step then takes the phase of the current estimate's STFT, combines it
    with ``S`` and inverts again (a Griffin-Lim style iteration).

    Args:
        S: Spectrogram as produced by ``stft``. It is used as given; it is
            presumably a magnitude spectrogram -- verify against callers.
        n_iter: Number of refinement steps. Defaults to 100, the value that
            was previously hard-coded. With 0 (or a negative value) the
            initial inverse STFT is returned unrefined.

    Returns:
        The reconstructed signal as returned by ``istft``.
    """
    x = istft(S)
    for _ in range(n_iter):
        # Unit-magnitude complex factor carrying only the phase of the
        # current estimate; multiplying by S re-imposes the target values.
        phase = np.exp(np.angle(stft(x)) * 1j)
        x = istft(S * phase)
    return x

# Probe the STFT once with RATE samples of silence to learn the spectrogram
# dimensions that a RATE-sample clip produces.
zeros = np.zeros(RATE)
S = stft(zeros)
SPECTROGRAM_SIZE_F, SPECTROGRAM_SIZE_T = S.shape  # rows, columns of the STFT matrix
SPECTROGRAM_SIZE = S.size  # total number of bins (rows * columns)
SPECTROGRAM_CHANNELS = 1
# Length of one flattened spectrogram as fed to the model.
FEATURE_SIZE = SPECTROGRAM_CHANNELS * SPECTROGRAM_SIZE

def rectify(x): # torch 1d to numpy 2d
    """Convert a flat torch tensor into a 2-D NumPy spectrogram array.

    The tensor is detached from the autograd graph, moved to host memory if
    it lives on another device (``Tensor.numpy()`` only accepts CPU tensors;
    ``.cpu()`` is a no-op for tensors already on the CPU), and reshaped to
    ``(SPECTROGRAM_SIZE_F, SPECTROGRAM_SIZE_T)``.

    Args:
        x: Tensor holding exactly ``SPECTROGRAM_SIZE_F * SPECTROGRAM_SIZE_T``
            elements.

    Returns:
        A NumPy array of shape ``(SPECTROGRAM_SIZE_F, SPECTROGRAM_SIZE_T)``.
    """
    return x.detach().cpu().numpy().reshape(SPECTROGRAM_SIZE_F, SPECTROGRAM_SIZE_T)


# Width of the hidden layer in SpeechModel: output size of its first linear
# layer and input size of its second.
# NOTE(review): the trailing "150" looks like an alternative setting kept for
# reference -- confirm before removing.
SQUASH_SIZE = 512 #150

class SpeechModel(torch.nn.Module):
    """Two-layer autoencoder over flattened spectrograms.

    ``encode`` maps ``FEATURE_SIZE`` inputs to ``SQUASH_SIZE`` ReLU
    activations; ``decode`` maps them back to ``FEATURE_SIZE`` values and
    squares the result, so decoded outputs are never negative.
    """

    def __init__(self):
        super().__init__()
        # Encoder projection: FEATURE_SIZE -> SQUASH_SIZE, with bias.
        self.m0 = nn.Linear(FEATURE_SIZE, SQUASH_SIZE, bias=True)
        # Decoder projection: SQUASH_SIZE -> FEATURE_SIZE, without bias.
        self.m1 = nn.Linear(SQUASH_SIZE, FEATURE_SIZE, bias=False)

    def encode(self, x):
        """Flatten ``x`` to rows of ``FEATURE_SIZE`` and return the hidden code."""
        flat = x.view(-1, FEATURE_SIZE)
        hidden = self.m0(flat)
        return torch.relu(hidden)

    def decode(self, y):
        """Map hidden codes ``y`` back to rows of ``FEATURE_SIZE`` values."""
        projected = self.m1(y)
        squared = projected ** 2
        return squared.view(-1, FEATURE_SIZE)

    def forward(self, x):
        """Encode ``x`` and decode it again."""
        return self.decode(self.encode(x))

# Paths of all files matching 'clip/clip-*', resolved relative to the current
# working directory when this module is executed. glob does not guarantee any
# particular ordering of the result.
clips = glob.glob('clip/clip-*')

def get_clip(name, distortion=1):
    """Load a WAV clip and return its flattened STFT magnitude as a tensor.

    The clip is resampled to ``len(data) / distortion`` samples (so a
    ``distortion`` above 1 shortens it), scaled by the peak absolute value
    of the whole resampled clip, truncated to ``RATE`` samples and
    zero-padded back up to ``RATE`` samples if it is shorter.

    Args:
        name: Path of the WAV file to read.
        distortion: Resampling factor; 1 leaves the clip length unchanged.

    Returns:
        A 1-D tensor of length ``FEATURE_SIZE`` holding ``abs(stft(clip))``.

    Raises:
        AssertionError: If the file's sample rate is not ``RATE``. Raised
            explicitly (instead of via ``assert``) so the check is not
            stripped when Python runs with ``-O``; the exception type is
            unchanged for existing callers.
    """
    rate, data = scipy.io.wavfile.read(name)
    if rate != RATE:
        raise AssertionError('expected sample rate %r, got %r' % (RATE, rate))
    data = scipy.signal.resample(data, int(len(data) / distortion))
    # Peak is taken over the full resampled clip, before truncation.
    peak = np.max(np.abs(data))
    data = data[:RATE]
    if peak > 0:
        data = data / peak
    # else: an all-zero clip is left as zeros; dividing by the zero peak
    # would turn every sample (and thus every feature) into NaN.
    data = np.pad(data, (0, RATE - len(data)))
    S = stft(data)
    return torch.tensor(np.abs(S)).reshape(FEATURE_SIZE)

def playback(name, distortion=1):
    """Play a WAV clip through IPython's audio widget.

    The clip goes through the same preparation as in ``get_clip``: it is
    resampled to ``len(data) / distortion`` samples, scaled by the peak
    absolute value of the whole resampled clip, truncated to ``RATE``
    samples and zero-padded up to ``RATE`` samples if shorter.

    Args:
        name: Path of the WAV file to read.
        distortion: Resampling factor; 1 leaves the clip length unchanged.

    Raises:
        AssertionError: If the file's sample rate is not ``RATE``. Raised
            explicitly (instead of via ``assert``) so the check is not
            stripped when Python runs with ``-O``; the exception type is
            unchanged for existing callers.
    """
    print('Playing %s @ %f' % (name, distortion))
    rate, data = scipy.io.wavfile.read(name)
    if rate != RATE:
        raise AssertionError('expected sample rate %r, got %r' % (RATE, rate))
    data = scipy.signal.resample(data, int(len(data) / distortion))
    # Peak is taken over the full resampled clip, before truncation.
    peak = np.max(np.abs(data))
    data = data[:RATE]
    if peak > 0:
        data = data / peak
    # else: an all-zero clip is left as zeros instead of becoming NaN.
    data = np.pad(data, (0, RATE - len(data)))
    # Imported here rather than at module level, so only this function
    # needs IPython to be installed.
    from IPython import display as ipd
    ipd.display(ipd.Audio(data, rate=RATE))